@radka-j
Member

@radka-j radka-j commented Oct 6, 2025

Closes #748
Closes #874
Closes #878
Closes #757

This PR:

  • adds a fit_from_reinitialized method that is used in both AutoEmulate.compare and HMW.refit_emulator
  • emulators now save all their input args so that all input values can be retrieved
  • replaces any **kwargs in emulators with an optional scheduler_kwargs keyword argument to match usage
  • updates HMW so that the user can pass an emulator as well as a result
  • updates Emulator.fit to handle InputLike instead of expecting only TensorLike
  • updates AL to expect emulator predictions to be DistributionLike rather than GaussianLike, to match TransformedEmulator prediction types
  • updates AL to use fit_from_reinitialized
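
The first two bullets can be illustrated with a minimal sketch. It assumes an emulator that records its constructor arguments and rebuilds itself from them. `ToyEmulator` and its `_init_args` attribute are hypothetical stand-ins, not the autoemulate implementation.

```python
class ToyEmulator:
    """Hypothetical stand-in for an emulator; not the autoemulate class."""

    def __init__(self, x, y, lr=0.1, scheduler_kwargs=None):
        # Record every constructor argument so a fresh copy can be built later.
        self._init_args = {"lr": lr, "scheduler_kwargs": scheduler_kwargs}
        self.lr = lr
        self.scheduler_kwargs = scheduler_kwargs or {}
        self.is_fitted = False

    def fit(self, x, y):
        # Placeholder "training": store the mean of the targets.
        self.mean_ = sum(y) / len(y)
        self.is_fitted = True
        return self

    def fit_from_reinitialized(self, x, y):
        # Build a new instance of the same class from the saved arguments,
        # then fit it from scratch. The original object is left untouched.
        fresh = type(self)(x, y, **self._init_args)
        return fresh.fit(x, y)


em = ToyEmulator([0.0, 1.0], [1.0, 3.0], lr=0.5, scheduler_kwargs={"gamma": 0.9})
refit = em.fit_from_reinitialized([0.0, 1.0, 2.0], [1.0, 3.0, 5.0])
```

Because the new instance is built with `type(self)`, this kind of refit does not need to look the class up by name.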


Base automatically changed from iss867/update_gp_factory to main October 6, 2025 13:19
@sgreenbury
Collaborator

Just adding a note here as I ran into this when working with a GP subclass for the error quantification. This call:

model_class = get_emulator_class(result.model_name)

fails since:
emulator_cls = EMULATOR_REGISTRY.get(
    name.lower()
) or EMULATOR_REGISTRY_SHORT_NAME.get(name.lower())

doesn't also look at:
GP_REGISTRY = {
    "GaussianProcess": GaussianProcess,
    "GaussianProcessCorrelated": GaussianProcessCorrelated,
}

@radka-j - adding here as it might be addressed by the upcoming changes to this API? But if not, I'm happy to open a new issue to look at this. An option could also be to revisit having a central registry class to handle this uniformly.
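
A rough sketch of the workaround described above: a lookup that also consults GP_REGISTRY. The registry names come from the snippets in this comment. Their contents here, the placeholder classes, and the case-insensitive matching are assumptions for illustration only.

```python
# Placeholder classes; the real ones live in autoemulate.
class GaussianProcess: ...
class GaussianProcessCorrelated(GaussianProcess): ...
class RandomForest: ...

# Assumed registry contents, for illustration only.
EMULATOR_REGISTRY = {"randomforest": RandomForest}
EMULATOR_REGISTRY_SHORT_NAME = {"rf": RandomForest}
GP_REGISTRY = {
    "GaussianProcess": GaussianProcess,
    "GaussianProcessCorrelated": GaussianProcessCorrelated,
}


def get_emulator_class(name):
    """Look a class up by name across all registries, ignoring case."""
    key = name.lower()
    for registry in (EMULATOR_REGISTRY, EMULATOR_REGISTRY_SHORT_NAME, GP_REGISTRY):
        for registered_name, cls in registry.items():
            if registered_name.lower() == key:
                return cls
    raise ValueError(f"Unknown emulator name: {name!r}")
```

The comparison lower-cases both sides because the GP_REGISTRY keys shown above are mixed-case while the existing lookup uses `name.lower()`.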

@radka-j
Member Author

radka-j commented Oct 6, 2025

@sgreenbury I don't think we should ever use the GaussianProcess or GaussianProcessCorrelated classes so this to me feels like correct behaviour. If we want a GP class for an RBF + constant kernel we should add that specifically to the registry.

@sgreenbury
Collaborator

It was in the GP context (passing a create_gp_subclass instance to AutoEmulate) that I ran into this issue, and a workaround might have been to also look at GP_REGISTRY, since this maintains a registry of all GPs including the created subclasses.

But thinking more about it, it affects any subclass used by AutoEmulate currently if reinitialize is called, e.g. in the advanced tutorial:

class SimpleFNN(PyTorchBackend):
    ...
ae = AutoEmulate(x, y, models=[SimpleFNN])
ae.fit_from_reinitialized(x, y)

since SimpleFNN is constructed at runtime the class is not found in the lists of emulators.

I think if the emulator becomes the entity that does the refitting in this PR then a global emulator registry including all custom subclasses would not be needed for this but might still be useful?
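
If a global registry that includes custom subclasses were still wanted, one possible shape is automatic registration via Python's `__init_subclass__` hook. This is only a sketch: the `registry` attribute and the empty class bodies are hypothetical, and the base classes are stand-ins rather than the autoemulate ones.

```python
class Emulator:
    # Hypothetical central registry, keyed by lower-cased class name.
    registry = {}

    def __init_subclass__(cls, **kwargs):
        # Called automatically whenever a subclass is defined, including
        # classes a user defines at runtime in a notebook.
        super().__init_subclass__(**kwargs)
        Emulator.registry[cls.__name__.lower()] = cls


class PyTorchBackend(Emulator):
    """Stand-in for the real backend class."""


class SimpleFNN(PyTorchBackend):
    """User-defined subclass, as in the advanced tutorial example."""


found = Emulator.registry["simplefnn"]
```

With this pattern a name-based lookup would find `SimpleFNN` without the user registering it by hand, at the cost of name collisions silently overwriting earlier entries.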

… GaussianLike to match TranformedEmulator predict type
@radka-j
Member Author

radka-j commented Oct 13, 2025

The logdet, trace and max_eigval plots in the AL documentation look wrong after the refactor here (they barely change). I started trying to figure out what's happening and have a sense that the predicted uncertainty is narrower when using a GP wrapped inside a TransformedEmulator (even without any transforms) than when using just a GP. I need to investigate this more formally, but we need to understand what's happening before we can merge this.

@radka-j
Member Author

radka-j commented Oct 14, 2025

I don't know what the issue is yet but my previous comment about the uncertainty from TransformedEmulator being narrower was wrong. I was comparing GP vs TransformedEmulator with GP using different learning rates. Once the same learning rate was used they look visually identical.

@sgreenbury
Collaborator

It might be related to whether posterior_predictive=True is being passed to the reinitialized GP when within the TransformedEmulator?

For example, on main in the dim reduction tutorial:
https://github.com/alan-turing-institute/autoemulate/blob/6d4a92fdcb2614b5dee5f907855e7003503c0910/docs/tutorials/emulation/02_dim_reduction.ipynb

em = ae.fit_from_reinitialized(x[train_idx], y[train_idx])

has:

print(em.model.posterior_predictive)
False

even though the original AutoEmulate initialization has posterior_predictive=True.

@radka-j
Member Author

radka-j commented Oct 14, 2025

Thank you for checking! In this case the posterior_predictive is correctly set to True after the emulator is re-initialized each time.

@radka-j
Member Author

radka-j commented Oct 14, 2025

@sgreenbury I'm also not sure if you saw my previous comment but the uncertainty output from TransformedEmulator seems to be fine.

@radka-j
Member Author

radka-j commented Oct 14, 2025

@sgreenbury I tried running the AL notebook using a GP wrapped inside a TransformedEmulator but calling emulator.fit (as originally implemented) instead of fit_from_reinitialized, and the results look very similar to the current docs. So it looks like the issue comes from re-initializing the emulator. Given the GP is refitting 1 data point at a time, this might be a case where calling fit with the hyperparameters fixed might actually make sense.

I therefore decided to revert this change and leave AL as is in this PR (only updating typing). We can separately decide whether to leave the associated issue (#757) open to revisit at some later point or close.
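
As a toy illustration only (not a diagnosis of the AL behaviour), the sketch below shows how continuing to fit an existing model and refitting from a reinitialized one can diverge when each call has a fixed training budget and data arrives one point at a time. `ToyModel`, its single parameter, and the gradient-descent loop are all hypothetical.

```python
class ToyModel:
    """Hypothetical one-parameter model trained by a few gradient steps."""

    def __init__(self, theta=0.0, lr=0.1, n_steps=5):
        self.init_theta = theta
        self.theta = theta
        self.lr = lr
        self.n_steps = n_steps

    def fit(self, y):
        # Warm start: continue from the current theta, minimising the
        # mean squared error between theta and the observations.
        for _ in range(self.n_steps):
            grad = sum(2 * (self.theta - yi) for yi in y) / len(y)
            self.theta -= self.lr * grad
        return self

    def fit_from_reinitialized(self, y):
        # Cold start: a fresh model with the original settings.
        return ToyModel(self.init_theta, self.lr, self.n_steps).fit(y)


def run_loop(reinitialize, n_rounds=5, target=10.0):
    """Add one observation per round and refit after each addition."""
    model = ToyModel()
    data = []
    for _ in range(n_rounds):
        data.append(target)
        if reinitialize:
            model = model.fit_from_reinitialized(data)
        else:
            model = model.fit(data)
    return model.theta


warm = run_loop(reinitialize=False)  # accumulates training across rounds
cold = run_loop(reinitialize=True)   # only ever gets n_steps from scratch
```

In this toy setting the warm-started parameter ends closer to the target because its training accumulates across rounds. Whether the same effect explains the AL plots is not established in this thread.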

Collaborator

@sgreenbury sgreenbury left a comment

Looks great, thanks @radka-j! As we discussed:

  • it looks like a choice of whether to reinitialize the emulator could be good to have in the API
  • there's an issue following our discussion capturing revisiting the overall workflow (#893)
  • the dimensionality reduction tutorial seems to not pick up the model_params={"posterior_predictive": True}

There is the comment above about DistributionLike not always having mean/variance - I don't think we'll run into this currently but might be good to either restrict here with the instance matching or have an issue for it.

Otherwise looks good to merge!
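
One way the "instance matching" restriction mentioned above could look is sketched below. It assumes a distribution base class whose mean and variance are not guaranteed to be implemented. The classes here are simplified stand-ins, not the torch or autoemulate types, and `mean_and_variance` is a hypothetical helper.

```python
class Distribution:
    """Stand-in base class: mean and variance are not guaranteed."""

    @property
    def mean(self):
        raise NotImplementedError

    @property
    def variance(self):
        raise NotImplementedError


class Normal(Distribution):
    """Stand-in Gaussian-like distribution with closed-form moments."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    @property
    def mean(self):
        return self.loc

    @property
    def variance(self):
        return self.scale ** 2


class SampleOnly(Distribution):
    """Stand-in distribution that exposes no analytic moments."""


def mean_and_variance(pred):
    # Restrict to types known to provide both moments, and fail early
    # with a clear message for anything else.
    if not isinstance(pred, Normal):
        raise TypeError(
            f"Expected a Gaussian-like prediction, got {type(pred).__name__}"
        )
    return pred.mean, pred.variance


moments = mean_and_variance(Normal(1.0, 2.0))
```

Failing with a TypeError at the call site makes the unsupported case explicit, rather than surfacing later as a NotImplementedError from inside a metric.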

@radka-j radka-j merged commit 3a63ee4 into main Oct 15, 2025
5 checks passed
@radka-j radka-j deleted the reinitialise branch October 15, 2025 10:11